//go:build linux || (windows && amd64)
// +build linux windows,amd64

/*
** Copyright (C) 2001-2025 Zabbix SIA
**
** This program is free software: you can redistribute it and/or modify it under the terms of
** the GNU Affero General Public License as published by the Free Software Foundation, version 3.
**
** This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
** without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
** See the GNU Affero General Public License for more details.
**
** You should have received a copy of the GNU Affero General Public License along with this program.
** If not, see <https://www.gnu.org/licenses/>.
**/

package tcpudp

import (
	"errors"
	"net"
	"os"
	"strconv"

	"github.com/sokurenko/go-netstat/netstat"
)

// exportNetTcpSocketCount - returns number of TCP sockets that match parameters.
//
// params holds up to five optional, positional filters:
//
//	params[0] - local address: an IP address or a network in CIDR notation
//	params[1] - local port: a number in the range 0-65535 or a service name
//	params[2] - remote address: an IP address or a network in CIDR notation
//	params[3] - remote port: a number in the range 0-65535 or a service name
//	params[4] - socket state: established, syn_sent, syn_recv, fin_wait1, fin_wait2,
//	            time_wait, close, close_wait, last_ack, listen or closing
//
// A missing or empty parameter leaves the corresponding filter unset; a numeric port of 0
// has the same effect because netStatTcpCount treats a zero port as "any".
// When more than five parameters are given errorTooManyParams is returned, and when a
// parameter cannot be parsed the error for its position is returned
// (errorInvalidFirstParam ... errorInvalidFifthParam). Parameters are validated in order,
// so the first invalid one determines the error.
func (p *Plugin) exportNetTcpSocketCount(params []string) (result int, err error) {
	if len(params) > 5 {
		return 0, errors.New(errorTooManyParams)
	}

	// param returns the i-th parameter, or "" when it was not supplied.
	param := func(i int) string {
		if i < len(params) {
			return params[i]
		}

		return ""
	}

	laddres, lNet, ok := parseTcpSocketAddr(param(0))
	if !ok {
		return 0, errors.New(errorInvalidFirstParam)
	}

	lport, ok := parseTcpSocketPort(param(1))
	if !ok {
		return 0, errors.New(errorInvalidSecondParam)
	}

	raddres, rNet, ok := parseTcpSocketAddr(param(2))
	if !ok {
		return 0, errors.New(errorInvalidThirdParam)
	}

	rport, ok := parseTcpSocketPort(param(3))
	if !ok {
		return 0, errors.New(errorInvalidFourthParam)
	}

	state, ok := parseTcpSocketState(param(4))
	if !ok {
		return 0, errors.New(errorInvalidFifthParam)
	}

	return netStatTcpCount(laddres, lNet, lport, raddres, rNet, rport, state)
}

// parseTcpSocketAddr - parses an address filter that is either a single IP address or a
// network in CIDR notation. At most one of the returned IP and network is non-nil; both are
// nil for an empty string. ok is false when the string is neither an IP nor a CIDR network.
func parseTcpSocketAddr(s string) (ip net.IP, ipNet *net.IPNet, ok bool) {
	if s == "" {
		return nil, nil, true
	}

	if ip = net.ParseIP(s); ip != nil {
		return ip, nil, true
	}

	_, ipNet, err := net.ParseCIDR(s)
	if err != nil {
		return nil, nil, false
	}

	return nil, ipNet, true
}

// parseTcpSocketPort - parses a port filter given either as a 16-bit unsigned number or as
// a service name. Service names are looked up for "tcp" first and then for "tcp6".
// An empty string yields port 0. ok is false when the string is neither.
func parseTcpSocketPort(s string) (port int, ok bool) {
	if s == "" {
		return 0, true
	}

	if n, err := strconv.ParseUint(s, 10, 16); err == nil {
		return int(n), true
	}

	for _, network := range []string{"tcp", "tcp6"} {
		if n, err := net.LookupPort(network, s); err == nil {
			return n, true
		}
	}

	return 0, false
}

// parseTcpSocketState - maps a lower-case state name to the corresponding netstat.SkState.
// An empty string yields the zero state. ok is false for an unknown name.
func parseTcpSocketState(s string) (state netstat.SkState, ok bool) {
	switch s {
	case "":
		return state, true
	case "established":
		return netstat.Established, true
	case "syn_sent":
		return netstat.SynSent, true
	case "syn_recv":
		return netstat.SynRecv, true
	case "fin_wait1":
		return netstat.FinWait1, true
	case "fin_wait2":
		return netstat.FinWait2, true
	case "time_wait":
		return netstat.TimeWait, true
	case "close":
		return netstat.Close, true
	case "close_wait":
		return netstat.CloseWait, true
	case "last_ack":
		return netstat.LastAck, true
	case "listen":
		return netstat.Listen, true
	case "closing":
		return netstat.Closing, true
	default:
		return state, false
	}
}

// netStatTcpCount - returns number of TCP sockets that match parameters.
//
// Both the IPv4 and the IPv6 socket tables are scanned and the matches are summed.
// A filter is applied only when its argument is set:
//
//	state          - skipped when 0
//	lport, rport   - skipped when 0; otherwise compared after conversion to uint16
//	laddres, lNet  - skipped when nil; local IP must equal laddres / be contained in lNet
//	raddres, rNet  - skipped when nil; remote IP must equal raddres / be contained in rNet
//
// Any error from reading the IPv4 table is returned. An error from reading the IPv6 table
// is returned unless it matches os.ErrNotExist, in which case the count gathered so far is
// returned (presumably to tolerate hosts without an IPv6 socket table - confirm against the
// netstat package).
func netStatTcpCount(laddres net.IP, lNet *net.IPNet, lport int, raddres net.IP, rNet *net.IPNet, rport int,
	state netstat.SkState) (result int, err error) {
	count := 0

	// accept is shared by the IPv4 and IPv6 scans. It only counts matching entries and
	// always returns false, since the entry list returned by the netstat package is
	// discarded by this function anyway.
	accept := func(s *netstat.SockTabEntry) bool {
		switch {
		case state != 0 && s.State != state:
		case lport != 0 && s.LocalAddr.Port != uint16(lport):
		case laddres != nil && !s.LocalAddr.IP.Equal(laddres):
		case lNet != nil && !lNet.Contains(s.LocalAddr.IP):
		case rport != 0 && s.RemoteAddr.Port != uint16(rport):
		case raddres != nil && !s.RemoteAddr.IP.Equal(raddres):
		case rNet != nil && !rNet.Contains(s.RemoteAddr.IP):
		default:
			// No active filter rejected the entry.
			count++
		}

		return false
	}

	if _, err = netstat.TCPSocks(accept); err != nil {
		return 0, err
	}

	if _, err = netstat.TCP6Socks(accept); err != nil && !errors.Is(err, os.ErrNotExist) {
		return 0, err
	}

	return count, nil
}
